import numpy as np
import torch.nn as nn

from .anchor_head_template import AnchorHeadTemplate

class AnchorHeadSingle(AnchorHeadTemplate):
    """Single-feature-map anchor head with an auxiliary point-feature branch.

    Shared 1x1 convolutions predict, per anchor, class logits, box regression
    codes and (when ``USE_DIRECTION_CLASSIFIER`` is configured) direction-bin
    logits from ``data_dict['spatial_features_2d']``.

    During training the *same* convolutions are additionally applied to
    ``data_dict['spatial_features_point_2d']``. Those outputs are stored in
    ``forward_ret_dict`` under ``*_point`` keys, together with the feature
    tensors used by the feature-similarity loss.
    """

    def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range,
                 predict_boxes_when_training=True):
        super().__init__(
            model_cfg=model_cfg, num_class=num_class, class_names=class_names, grid_size=grid_size,
            point_cloud_range=point_cloud_range, predict_boxes_when_training=predict_boxes_when_training
        )

        # The template exposes anchor counts as a sequence (presumably one entry
        # per anchor class -- confirm in AnchorHeadTemplate). This head predicts
        # all anchors from a single feature map, so only the total is needed.
        self.num_anchors_per_location = sum(self.num_anchors_per_location)

        self.conv_cls = nn.Conv2d(
            input_channels, self.num_anchors_per_location * self.num_class,
            kernel_size=1
        )
        self.conv_box = nn.Conv2d(
            input_channels, self.num_anchors_per_location * self.box_coder.code_size,
            kernel_size=1
        )

        if self.model_cfg.get('USE_DIRECTION_CLASSIFIER', None) is not None:
            self.conv_dir_cls = nn.Conv2d(
                input_channels,
                self.num_anchors_per_location * self.model_cfg.NUM_DIR_BINS,
                kernel_size=1
            )
        else:
            self.conv_dir_cls = None
        self.init_weights()

    def init_weights(self):
        """Initialise the classification bias and the box-regression weights."""
        pi = 0.01
        # bias = -log((1 - pi) / pi) makes sigmoid(bias) == pi, so every anchor
        # starts with a predicted foreground probability of 1%.
        nn.init.constant_(self.conv_cls.bias, -np.log((1 - pi) / pi))
        nn.init.normal_(self.conv_box.weight, mean=0, std=0.001)

    def _run_heads(self, features):
        """Apply the shared conv heads to one feature map.

        Args:
            features: input to the ``nn.Conv2d`` heads (NCHW, as required by
                ``nn.Conv2d`` with ``input_channels`` channels).

        Returns:
            Tuple ``(cls_preds, box_preds, dir_cls_preds)``. Each tensor is
            permuted to channels-last ``[N, H, W, C]``. ``dir_cls_preds`` is
            ``None`` when no direction classifier is configured.
        """
        cls_preds = self.conv_cls(features).permute(0, 2, 3, 1).contiguous()  # [N, H, W, C]
        box_preds = self.conv_box(features).permute(0, 2, 3, 1).contiguous()  # [N, H, W, C]
        if self.conv_dir_cls is None:
            dir_cls_preds = None
        else:
            dir_cls_preds = self.conv_dir_cls(features).permute(0, 2, 3, 1).contiguous()
        return cls_preds, box_preds, dir_cls_preds

    def forward(self, data_dict):
        """Run the head on the BEV features and, in training, the point features.

        Args:
            data_dict: batch dictionary. Always read: ``spatial_features_2d``
                and, when boxes are decoded, ``batch_size``. Training additionally
                reads ``spatial_features_point_2d``, ``point_positive_features``,
                ``memory_positive_features``, ``memory_items`` and ``gt_boxes``.

        Returns:
            ``data_dict``. When boxes are decoded (always in eval, and in
            training only if ``predict_boxes_when_training``), it is augmented
            with ``batch_cls_preds``, ``batch_box_preds`` and
            ``cls_preds_normalized``. In training it also receives
            ``batch_cls_preds_point`` and ``batch_box_preds_point``.
        """
        cls_preds, box_preds, dir_cls_preds = self._run_heads(data_dict['spatial_features_2d'])
        self.forward_ret_dict['cls_preds'] = cls_preds
        self.forward_ret_dict['box_preds'] = box_preds
        if dir_cls_preds is not None:
            self.forward_ret_dict['dir_cls_preds'] = dir_cls_preds

        if self.training:
            # Auxiliary branch: the same heads applied to the point-derived features.
            cls_preds_point, box_preds_point, dir_cls_preds_point = self._run_heads(
                data_dict['spatial_features_point_2d']
            )
            self.forward_ret_dict['cls_preds_point'] = cls_preds_point
            self.forward_ret_dict['box_preds_point'] = box_preds_point
            if dir_cls_preds_point is not None:
                self.forward_ret_dict['dir_cls_preds_point'] = dir_cls_preds_point

            # Inputs for the feature-similarity loss.
            self.forward_ret_dict['pos_point_feas'] = data_dict['point_positive_features']
            self.forward_ret_dict['pos_memory_feas'] = data_dict['memory_positive_features']
            self.forward_ret_dict['memory_items'] = data_dict['memory_items']

            targets_dict = self.assign_targets(
                gt_boxes=data_dict['gt_boxes']
            )
            self.forward_ret_dict.update(targets_dict)

        if not self.training or self.predict_boxes_when_training:
            batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
                batch_size=data_dict['batch_size'],
                cls_preds=cls_preds, box_preds=box_preds, dir_cls_preds=dir_cls_preds
            )
            data_dict['batch_cls_preds'] = batch_cls_preds
            data_dict['batch_box_preds'] = batch_box_preds
            data_dict['cls_preds_normalized'] = False

            if self.training:
                # Previously these results were computed and then dropped: the
                # main keys were overwritten instead. They are now exposed under
                # dedicated *_point keys.
                # NOTE(review): confirm downstream consumers expect these key names.
                batch_cls_preds_point, batch_box_preds_point = self.generate_predicted_boxes(
                    batch_size=data_dict['batch_size'],
                    cls_preds=cls_preds_point, box_preds=box_preds_point,
                    dir_cls_preds=dir_cls_preds_point
                )
                data_dict['batch_cls_preds_point'] = batch_cls_preds_point
                data_dict['batch_box_preds_point'] = batch_box_preds_point

        return data_dict

# class AnchorHeadSingle(AnchorHeadTemplate):
#     def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range,
#                  predict_boxes_when_training=True):
#         super().__init__(
#             model_cfg=model_cfg, num_class=num_class, class_names=class_names, grid_size=grid_size, point_cloud_range=point_cloud_range,
#             predict_boxes_when_training=predict_boxes_when_training
#         )

#         self.num_anchors_per_location = sum(self.num_anchors_per_location)

#         self.conv_cls = nn.Conv2d(
#             input_channels, self.num_anchors_per_location * self.num_class,
#             kernel_size=1
#         )
#         self.conv_box = nn.Conv2d(
#             input_channels, self.num_anchors_per_location * self.box_coder.code_size,
#             kernel_size=1
#         )

#         if self.model_cfg.get('USE_DIRECTION_CLASSIFIER', None) is not None:
#             self.conv_dir_cls = nn.Conv2d(
#                 input_channels,
#                 self.num_anchors_per_location * self.model_cfg.NUM_DIR_BINS,
#                 kernel_size=1
#             )
#         else:
#             self.conv_dir_cls = None
#         self.init_weights()

#     def init_weights(self):
#         pi = 0.01
#         nn.init.constant_(self.conv_cls.bias, -np.log((1 - pi) / pi))
#         nn.init.normal_(self.conv_box.weight, mean=0, std=0.001)

#     def forward(self, data_dict):
#         spatial_features_2d = data_dict['spatial_features_2d']

#         cls_preds = self.conv_cls(spatial_features_2d)
#         box_preds = self.conv_box(spatial_features_2d)

#         cls_preds = cls_preds.permute(0, 2, 3, 1).contiguous()  # [N, H, W, C]
#         box_preds = box_preds.permute(0, 2, 3, 1).contiguous()  # [N, H, W, C]

#         self.forward_ret_dict['cls_preds'] = cls_preds
#         self.forward_ret_dict['box_preds'] = box_preds

#         if self.conv_dir_cls is not None:
#             dir_cls_preds = self.conv_dir_cls(spatial_features_2d)
#             dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous()
#             self.forward_ret_dict['dir_cls_preds'] = dir_cls_preds
#         else:
#             dir_cls_preds = None

#         if self.training:
#             targets_dict = self.assign_targets(
#                 gt_boxes=data_dict['gt_boxes']
#             )
#             self.forward_ret_dict.update(targets_dict)

#         if not self.training or self.predict_boxes_when_training:
#             batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
#                 batch_size=data_dict['batch_size'],
#                 cls_preds=cls_preds, box_preds=box_preds, dir_cls_preds=dir_cls_preds
#             )
#             data_dict['batch_cls_preds'] = batch_cls_preds
#             data_dict['batch_box_preds'] = batch_box_preds
#             data_dict['cls_preds_normalized'] = False

#         return data_dict
